
feat: Whisper prompting #22496

Merged
amyeroberts merged 35 commits into huggingface:main from connor-henderson:whisper-prompting
May 19, 2023

Conversation

@connor-henderson
Contributor

connor-henderson commented Mar 31, 2023

What does this PR do?

Closes #22395, thank you @sanchit-gandhi for the descriptive ask!

Note: due to initial scope expansion the commit history includes initial work towards condition_on_previous_text, always_use_initial_prompt, and pipeline integration, but these efforts have been pushed to a later PR

This pull request adds 3 new functionalities + tests to support initial prompting within Whisper's model.generate() and tokenizer:

  • prompt_ids param for model.generate():
    • Optional param of initial prompt ids used to provide context for each chunk of text generated in model.generate()
  • get_prompt_ids processor method to create initial prompt ids from a passed-in string, to be passed to generate
  • Removal of the prompt during tokenizer decoding when skip_special_tokens=True

Example new API usage:

from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# input_speech: a raw audio array (e.g., a sample from a speech dataset)
input_features = processor(input_speech, return_tensors="pt").input_features

# --- Without prompt ---
output_without_prompt = model.generate(input_features)
print(processor.decode(output_without_prompt[0]))
# "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"

# --- With prompt ---
prompt_ids = processor.get_prompt_ids("Leighton")
output_with_prompt = model.generate(input_features, prompt_ids=prompt_ids)
print(processor.decode(output_with_prompt[0]))
# "<|startofprev|> Leighton<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Leighton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings. (I haven't added docs anywhere outside of documenting the new generate() arg directly on the function.)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sanchit-gandhi

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Mar 31, 2023

The documentation is not available anymore as the PR was closed or merged.

@connor-henderson connor-henderson marked this pull request as ready for review March 31, 2023 19:21
@hollance
Contributor

hollance commented Apr 3, 2023

Hey this PR looks really good (although I'll leave the actual review to Sanchit or Arthur).

I was just wondering whether it also makes sense to support the condition_on_previous_text option that the OpenAI repo has, since that uses the same mechanism (using the <|startofprev|> token).

In addition, there's this PR that suggests an always_use_initial_prompt option that uses the prompt on every segment, not just the first. Might be useful to consider that here as well.

@connor-henderson
Contributor Author

> Hey this PR looks really good (although I'll leave the actual review to Sanchit or Arthur).
>
> I was just wondering whether it also makes sense to support the condition_on_previous_text option that the OpenAI repo has, since that uses the same mechanism (using the <|startofprev|> token).
>
> In addition, there's this PR that suggests an always_use_initial_prompt option that uses the prompt on every segment, not just the first. Might be useful to consider that here as well.

Hey Matthijs, thanks, I'm happy to add what's wanted. Will look for HF guidance on that and on whether it should be added here or in a follow-on PR. temperature was another factor I saw in the Whisper model: if it was > 0.5, no prompt tokens were added (link).

Comment on lines 1618 to 1642
Contributor Author

Replacing this with an actual conditional now. Any idea how the model test I added passed with this?

Contributor Author
connor-henderson Apr 3, 2023

I don't think this last line handles all possible datatypes of token_ids, particularly int, torch.Tensor, and ndim > 1 np arrays. Maybe we should use to_py_obj above it first?

Contributor

Maybe it's not needed to check for has_initial_prompt and simply always skip everything until the bos_token?

Contributor Author
connor-henderson Apr 4, 2023

Oh good idea will use the bos_token instead of the prompt start token. Regarding always skipping, do we want to show the prompt when skip_special_tokens is False like in this example or no?

output = processor.batch_decode(pred_ids, skip_special_tokens=False)
# output: ['<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>']

Contributor Author

Looking into this it appears the bos_token is <|endoftext|> unless otherwise set, which we couldn't use for slicing

Contributor
hollance left a comment

Hi @connor-henderson, I was asked by @sanchit-gandhi to do a code review for your PR. It looks pretty good already, just needs a bit of fine-tuning. I'm only just getting familiar with the Whisper code myself, so my opinions don't necessarily always make sense. ;-)

Contributor

Sanchit suggested the argument name prompt_ids rather than initial_prompt_ids and I have to agree with that; the word prompt already implies that it precedes what happens.

I'm also wondering if this should be Optional[torch.Tensor] rather than List[int], just like input_ids in the HF NLP models. It feels "wrong" to use a list here (since we generally always use Tensors for tokenized input), even though it makes sense with how this variable is used. Maybe @sanchit-gandhi has an opinion about this?

Contributor Author

Sounds good, renamed to prompt_ids.

With regards to the type, I wonder if this is ok here because prompt_ids' purpose is just to update forced_decoder_ids and decoder_start_token_id, which are type int and List[List[int]] (if I'm not mistaken, saw that here)? If we used torch.Tensor, we would also have to import torch in the file with the get_prompt_ids function, and I don't believe we'd ever require tensor functionality. Just lmk whichever you prefer

Contributor

What kind of window does this refer to? (I'm assuming what's called window here is what we call a chunk. If that's the case we should be consistent with the terminology.)

Contributor Author

The "first window" refers to the initial segment of the input speech, so I believe it can be used interchangeably with chunk yes. Updated the wording to use 'chunk'

Contributor

Nitpick: The forced_decoder_ids use a tuple instead of a list, so this should be a tuple too:

indexed_initial_prompt_ids = [(rank + 1, token) for rank, token in enumerate(initial_prompt_ids)]

I also used rank instead of idx and token instead of id, to be consistent with tokenizer.get_decoder_prompt_ids.

Contributor

I'm also not 100% sure that generation_config.forced_decoder_ids is always filled in here. For example, if we do the following:

forced_decoder_ids = processor.get_decoder_prompt_ids(language="de", task="translate")
predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)

then generation_config.forced_decoder_ids may not have the <|de|> token at this point. In this situation, the correct forced_decoder_ids will be filled in by super().generate(...), as far as I can tell.

Contributor Author

Ah good catch! Yes, before, the implementation was overriding provided forced_decoder_ids, but this should be fixed now. Previously it was adding these from the generation_config only, but now it checks kwargs first.

Contributor

I prefer get_prompt_ids() as the name for this method.

Not sure that it needs to live in the tokenizer (since that involves duplicate code for the fast tokenizer). Perhaps just having it in the processor is enough?

Contributor Author

Cool will rename and move it to just being on the processor. My thinking was that users would also want to access the method on the tokenizers directly, but I wasn't sure.

Contributor

Perhaps this could be an instance variable.

Contributor Author

Now that we've updated this decode function with your above suggestion this line was removed. This same logic is now only used once and it's on the processor. Would you still like me to make it an instance var there?

Contributor

If it's just used once then it's not really worth making it an instance variable.

Contributor

Maybe it's not needed to check for has_initial_prompt and simply always skip everything until the bos_token?

Contributor

What is the reason for doing " " + text.strip()?

Contributor Author

This is what was done in the Whisper implementation. I believe the model expects whitespace to be stripped from the beginning and end of the string, with a single space added at the start.
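
As a small illustration of that normalization (assumed behavior, shown for clarity rather than taken from the diff):

text = "  Leighton  "
normalized = " " + text.strip()  # -> " Leighton"
# Whisper's byte-level BPE treats a leading space as part of the first word,
# so " Leighton" and "Leighton" tokenize to different ids.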

Contributor

Small typo in the function name. :-) But excellent work on adding these unit tests!

@connor-henderson
Contributor Author

connor-henderson commented Apr 5, 2023

To-do list before re-requesting review

  • Converting the prompt token to an ID in an instance variable gives an incorrect ID, unlike when it's called in decode
    --Given we're only using it in two places and it's an inexpensive op to call convert_tokens_to_ids I've left this, at least for now, to focus more on the below
  • Bug I found where if the ending text of the prompt matches the start of the transcribed text, that text will not be included in the transcription output. Example:
    --I'm actually not sure this is a bug now. The model has learned to avoid repeating itself, and this only happens if the end of the prompt matches the beginning of the transcription almost exactly. It also appears to be happening inside the model itself as opposed to in the logits processing or other modification before / after.

[Screenshot of the example output omitted]

Added from @hollance's below two comments:

  • Add always_use_initial_prompt and condition_on_previous_text options to pipeline and model.generate()
  • Add prompting functionality to the automatic-speech-recognition pipeline

@hollance
Contributor

hollance commented Apr 5, 2023

One more thing we'll need to do, is change the automatic-speech-recognition pipeline so that it will actually call model.generate() with the prompt, but only for the first chunk (or always if we also decide to support an always_use_initial_prompt option). This logic cannot be part of the modeling code, as model.generate() has no knowledge of which chunk of audio it's processing.

@hollance
Contributor

hollance commented Apr 5, 2023

I looked a bit more into how this works today, and it turns out that 🤗 Transformers does things a bit differently than the original OpenAI code.

OpenAI does the following:

For the first 30-second chunk of audio, it passes the following token sequence to the model's decoder on the first iteration: <|startofprev|> initial prompt<|startoftranscript|><|en|><|transcribe|>. And then it decodes the rest of the sequence autoregressively.

Then for the second chunk of audio, it passes the following sequence to the decoder on the first iteration: <|startofprev|> initial prompt output of the first chunk<|startoftranscript|><|en|><|transcribe|>.

For the next chunk, it uses <|startofprev|> initial prompt output of the first chunk output of the second chunk<|startoftranscript|><|en|><|transcribe|>

And so on... This list of tokens that it passes in the <|startofprev|> section grows longer and longer with each new chunk.

(When you set the condition_on_previous_text option to False, it only uses the output from the previous chunk instead of the complete history. In that case the initial prompt text is only used for the very first chunk.)

Our ASR pipeline works quite differently. It also splits up the audio in 30-second chunks but they partially overlap, and then it runs the model on these chunks in parallel. That makes it impossible to pass the previous context to these chunks, as each chunk is processed independently. So we have no way of sending <|startofprev|> initial prompt output of the first chunk<|startoftranscript|><|en|><|transcribe|> to the second chunk.

The best we can do is send <|startofprev|> initial prompt<|startoftranscript|><|en|><|transcribe|> to the very first chunk only, or always send it to all chunks. So we ignore the "previous context" part and always include the prompt. (The latter would do the same as this open PR on the OpenAI repo for always passing the initial prompt inside <|startofprev|> instead of the previous context.)

The suggested modifications to model.generate() in this PR make it possible to have both initial_prompt and the condition_on_previous_text options as in OpenAI, but it would require the user to write their own processing loop to get the same results as OpenAI. So we should definitely continue with this PR, but if we also want to support initial_prompt in the pipeline we'll have to decide on which approach we want. (It's not possible to have condition_on_previous_text in the current pipeline.)
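
To make that concrete, a minimal user-side sketch of such a sequential loop, built on the prompt_ids support added in this PR, could look like the following (chunked_input_features is a hypothetical list of pre-chunked, non-overlapping 30-second features; this is an illustration, not the pipeline implementation):

initial_prompt = "Leighton"
previous_text = ""
transcript = []

for chunk_features in chunked_input_features:
    # prepend the initial prompt plus everything transcribed so far as the <|startofprev|> context
    prompt_ids = processor.get_prompt_ids(initial_prompt + previous_text)
    pred_ids = model.generate(chunk_features, prompt_ids=prompt_ids)
    chunk_text = processor.decode(pred_ids[0], skip_special_tokens=True)
    transcript.append(chunk_text)
    previous_text += chunk_text  # condition_on_previous_text: the context grows with each chunk

print("".join(transcript))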

@hollance
Contributor

>   • We can provide a prompt in the pipeline like the below without modifying the pipeline at all, works for me locally. Is this sufficient / what you had in mind?

You are correct that when you do the following,

pipe = pipeline(task="automatic-speech-recognition", model="openai/whisper-tiny")
res = pipe(samples, generate_kwargs={ "prompt_ids": prompt_ids })

the pipeline will automatically pass the prompt_ids to model.generate(). However note that this pipeline only processes the first 30 seconds of the audio file. This is fine for audio that is shorter than 30 seconds.

However, to process an audio file that is longer than 30 seconds, we have to do:

res = pipe(example, generate_kwargs={ "prompt_ids": prompt_ids }, chunk_length_s=30, stride_length_s=[6, 0])

Now the same prompt_ids are passed to model.generate() for each 30-second chunk. In effect, this is the always_use_initial_prompt option.

To get the regular initial_prompt (i.e. always_use_initial_prompt disabled) and condition_on_previous_text behavior as they work in OpenAI with the current pipeline, we'd have to pass in a stride_length_s=[0,0] and batch_size=1 to make the loop work sequentially rather than in parallel, and somehow keep track of the previous outputs.

connor-henderson changed the title from "feat: Whisper initial prompting" to "feat: Whisper prompting" on Apr 14, 2023
@connor-henderson
Contributor Author

connor-henderson commented Apr 14, 2023

Ok the additional requested features are now added so I believe this is ready for re-review. Thank you for your comments!

> However note that this pipeline only processes the first 30 seconds of the audio file. This is fine for audio that is shorter than 30 seconds... In effect, this is the always_use_initial_prompt option.

I think I’m missing something here as I’ve tried this on >1 min of audio in the below example where I also added a debug line to decode the tokens inside of the pipeline as they were generated, and it appears to be properly sequential. In any case, if we don’t want this I’ll remove condition_on_previous_text from the pipeline just lmk!

pipe = pipeline(task="automatic-speech-recognition", model="openai/whisper-tiny")
res = pipe(samples, generate_kwargs={ "condition_on_previous_text": True, "prompt_ids": prompt_ids })
# ['<|startofprev|><|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.<|endoftext|>']
# ["<|startofprev|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Nor is Mr. Quilter's manner less interesting than his matter.<|endoftext|>"]
# ["<|startofprev|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter.<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He tells us that at this festive season of the year with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind.<|endoftext|>"]
# ["<|startofprev|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind.<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"]
# ["<|startofprev|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Lennils, pictures are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampoo or a turkish bath. Next man<|endoftext|>"]
# ["<|startofprev|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca. Lennils, pictures are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampoo or a turkish bath. Next man<|startoftranscript|><|en|><|transcribe|><|notimestamps|> it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression.<|endoftext|>"]
# ["<|startofprev|> middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca. Lennils, pictures are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampoo or a turkish bath. Next man it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression.<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art and Mr. Quilter writes with equal lucidity.<|endoftext|>"]

> The suggested modifications to model.generate() in this PR make it possible to have both initial_prompt and the condition_on_previous_text options as in OpenAI, but it would require the user to write their own processing loop to get the same results as OpenAI.

Aimed to address this with the new sequential loop over chunks of the input. Right now this way is incompatible with return_dict_in_generate=True, as I wasn't sure how / if we'd still want to return several ModelOutputs; looking for guidance here.


Also, there are hacks in a few places related to getting the id of the prompt start token and separating it from the prompt text ids. Would this be something we could add to the model or generation config?

Contributor
hollance left a comment

Thanks for working on this feature, @connor-henderson. I think you were very inventive in coming up with a solution that allows us to use initial_prompt and condition_on_previous_text as in OpenAI. 😄

However, your implementation doesn't seem to fit in very well with the current design of Transformers. I'll let my colleagues at HF weigh in too, but it might be better to split this functionality as follows:

  1. Add the prompt_ids to model.generate() as in your earlier version of the PR. All this does is insert the prompt in the <|startofprev|> section. This doesn't give us the OpenAI functionality yet, it only adds <|startofprev|> support to the modeling and tokenizer code.

  2. Create a new pipeline that is specific to Whisper that works more like the OpenAI inference code does. The logic for managing the <|startofprev|> section then sits in the new pipeline's loop, not in the model.

(Perhaps step 2 could be a separate PR, to keep the complexity of these PRs down a bit.)

Comment on lines 1674 to 1684
Contributor

Adding a loop in model.generate() is a clever solution to get the pipeline to work sequentially, but it's also a bit hacky. I don't think it's the right approach for Transformers.

Comment on lines 95 to 99
Contributor

Rather than returning the prompt_ids as a list of integers, it would be preferable to have them as a tensor. But even better, get_prompt_ids() should use the return_tensors argument just like tokenizer does, so that the caller can decide between numpy or torch tensors, or a list of integers.
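
For example, the suggested signature would let the caller pick the container (a hedged sketch of the proposed API, mirroring the tokenizer's return_tensors argument):

prompt_ids = processor.get_prompt_ids("Leighton")                       # default container (e.g. a list / numpy array)
prompt_ids = processor.get_prompt_ids("Leighton", return_tensors="pt")  # torch tensor, ready for model.generate()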

@amyeroberts
Contributor

cc'ing in @gante re generate

@connor-henderson
Contributor Author

>   1. Add the prompt_ids to model.generate() as in your earlier version of the PR. All this does is insert the prompt in the <|startofprev|> section. This doesn't give us the OpenAI functionality yet, it only adds <|startofprev|> support to the modeling and tokenizer code.

Thanks @hollance, I definitely agree splitting this into >1 PR is ideal; I've pushed the code for number 1 above back up so this PR can just address that portion. It now implicitly does always_use_initial_prompt.

@connor-henderson
Contributor Author

Curious if, by adding return_tensors to get_prompt_ids, you're setting up to effectively do condition_on_previous_text via cleverly feeding batches / prompts to model.generate() calls (i.e. the first chunk of a second model.generate call would use the text from the first chunk of the first model.generate call as a prompt, and so on for each chunk in the batch), but that's more of a question for subsequent PRs

@hollance
Contributor

The reason I asked for the return_tensors argument is that passing the prompt_ids into model.generate() as a torch.LongTensor instead of List[int] is more consistent with how we normally pass tokens into Transformers models. I understand that inside the model you might need to turn it into a list anyway for the forced_decoder_ids, but that's really an internal implementation detail. When we generate, the output token sequence is also a Tensor, and so we can concat this to the previous prompt_ids to create the next one, etc. I hope that makes sense. :-)

Contributor
gante left a comment

The PR LGTM as it is, thank you for the contribution 🙌

BTW, the code changes do not match the description at the top. From what I gathered in the comments, there will be a follow-up PR, correct? In that case, would you mind updating the PR, before I tag a core maintainer for a final review? :)

Contributor

Is there a reason behind this slicing? Intuitively it makes sense to me, but I'm curious to know if there is a reference behind this choice :)

Contributor Author

Sure I'll leave a comment in the code too, this is done to match Whisper's implementation. I believe the reason they do the -1 is to make room for the first token to generate, and the reason they do // 2 is to halve it to share context space with a prefix if one is provided (which also gets halved). I don't believe there's prefix support yet in transformers so technically the // 2 isn't necessary at this point but I didn't want to confuse future work around that if it happens. There's a good clarification of prompt vs prefix here if it's of interest.


Hello @connor-henderson, as I am using the prompting feature I noticed a bug for long prompts. It might be caused by the slicing, where it should be text_prompt_ids = text_prompt_ids[-(self.config.max_length // 2 - 1) :], to correctly account for the first token <|startofprev|>.

Contributor

Hey @Helene-Maxcici, feel free to open a new issue to track this bug, tagging myself (and optionally @connor-henderson). In particular, it would be super helpful to have a reproducible code snippet to emulate the behaviour locally. See the following page for details: https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#submitting-a-bug-related-issue-or-feature-request

Contributor
hollance left a comment

This PR is shaping up nicely, @connor-henderson! I think this PR has the right amount of changes and then we can figure out how to do the sequential generation in a follow-up PR.

I've added a bunch of remarks and suggestions so we can make this fit as well into Transformers as possible. 😄

I'd also like to invite my colleagues @sanchit-gandhi and @ArthurZucker to have a look at these changes.

Contributor

My suggestion is prompt_ids: Optional[torch.Tensor] = None but I'll let my colleagues weigh in too. @sanchit-gandhi @ArthurZucker

Contributor Author

Changed to this suggestion

Comment on lines +1639 to +1648
Contributor

Nice! I think this now supports the different ways that forced_decoder_ids may be passed in?

  1. Through model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language=..., task=...)

  2. Through model.generate(input_features, forced_decoder_ids=forced_decoder_ids)

  3. Through model.generate(input_features, language=..., task=...)

It would be good if there are unit tests for these different methods.

Contributor Author
connor-henderson May 2, 2023

I don't believe model.generate allows passing in task or language directly as in 3. above, but I've now added tests for the other two

Contributor

It does allow that (and I think it might even be the preferred method now) but for some reason the language needs to be the token, such as "<|ja|>" rather than "ja".

Contributor

cc @gante

Contributor
gante May 17, 2023

@connor-henderson update on the language code: we now support passing the language token, the language code, or the language name. See this (very recent) PR :)

(not sure if this info has gotten to you, many conversations in parallel in this PR)

Contributor

Note that the language code change was @connor-henderson's most recent PR! This forced_decoder_ids logic is in place so that the code is backward compatible with our previous way of handling the language/task, where we either set it in the config as config.forced_decoder_ids, or explicitly as forced_decoder_ids to the generate method (see #21965 (comment))

Contributor

haha derp, I didn't look at the author 🙈 my bad!

Contributor

Not sure I'm happy with this since token_ids is also used below in the call to super().decode(...).

Contributor Author

Yea I agree, moved this prompt removal code after that super().decode(...) call to _decode so this conversion isn't necessary

Comment on lines 593 to 598
Contributor

Edge case: prompt_end_idx is not set when token_ids has length 1. Maybe rewrite it to this:

if skip_special_tokens and isinstance(token_ids, list):
    prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
    if prompt_token_id in token_ids:
        for i in range(1, len(token_ids)):
            if token_ids[i] in self.all_special_ids:
                token_ids = token_ids[i:]
                break

Although perhaps it's easiest to check if the very first token is <|startofprev|> rather than doing prompt_token_id in token_ids?

Comment on lines 299 to 309
Contributor

Since this is the same logic as in the regular tokenizer, maybe we can extract it into a shared helper function?

Contributor Author

moved some of the functionality into a helper and left some for what I think is the right reusability/readability tradeoff, lmk if you think more should be abstracted

Contributor

I can understand putting this into a free function but if it's only used in one class, we generally keep it as a member function.

Contributor Author

For sure, removed this change. I'd been moving the tests around and at one point I had two classes using this.

Comment on lines 1410 to 1482
Contributor

Nice test! I'd like to see a few more tests where you also change the forced_decoder_ids (see my comment above). The way the forced_decoder_ids get passed around is a bit brittle (due to the code for that changing a few times) and so we should make sure we have solid tests, since it's too easy for someone to change how this works and inadvertently break something.

Contributor

Could you also add some tests for edge cases?

For example: processor.get_prompt_ids("") or processor.get_prompt_ids("<|startofprev|> Mr. <|startofprev|> Quilter")

Contributor Author
connor-henderson May 3, 2023

The second will definitely confuse the model and decoding if they were passed to the current get_prompt_ids as is, would you prefer we strip the prompt start token or raise an error that it was included? I'll push up a change that strips it for now, lmk which you prefer and if you'd want to log a warning as well

Contributor

I don't really know what would be the best approach here, was just trying to think of things that might go wrong. ;-)

Perhaps raising an error on unexpected input is the best choice, but only if it doesn't add a lot of complexity to the code.

Contributor Author

Looks like they have their tiktoken package handle it and it raises an error if any special token is included, so will look to do the same

Contributor Author

I feel like this is most readable / simplest as one test with comments clarifying the cases, lmk if you want them split into separate unit tests

Contributor

I agree it's very readable but there's a potential issue: the model will keep state around, i.e. it will change the model.generation_config object with the new forced_decoder_ids and this may affect the next test. So I think it's better to instantiate a new model before each test.

Maybe it's also a good idea to test what happens after you do the following, just to make sure the code can handle both of these things being None:

model.config.forced_decoder_ids = None
model.generation_config.forced_decoder_ids = None

Contributor Author

Added a case for the above, which involved a change in generate I'll call out below. I was aiming to order the tests to prevent conflicting state issues, but you're right, they're more brittle that way; split them into individual tests.

Contributor

Sorry, I wasn't able to fully understand the last comment - for testing the case when:

model.config.forced_decoder_ids = None 
model.generation_config.forced_decoder_ids = None

is this tested?

Contributor Author

Sorry, moving parts: we had a test explicitly for this when there were 5 test cases. Then we trimmed them; per this #22496 (comment) I changed the test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids test to use whisper-base.en and return_timestamps=True. I just tested it though and realized that combination didn't actually set those attributes to None, so I updated the test to explicitly set those two to None.

tl;dr it was tested, then wasn't, now is again

Contributor Author

Went back to using an 'in' check here instead of indexing to 0 in token_ids, since that errors if it's empty.

Contributor

The reason I suggested putting it inside a separate if skip_special_tokens: is that the in operation needs to scan through the list, which is slow, and we can avoid it if skip_special_tokens is False.

Contributor Author

Good point, I'll adapt it to check index 0.

Contributor
hollance left a comment

We're slowly getting there. 😄 (The reason I'm being so nitpicky is that we're going to be changing a model that's already being used a lot, so we have to be very careful to make the right decisions.)

Comment on lines 1531 to 1532
Contributor

Doc comment still refers to the old type hints.

Comment on lines 97 to 99
Contributor

Would it be better to strip out all special tokens / raise an error?

Not sure what OpenAI does here. In "condition on previous text mode" they don't include the <|startoftranscript|><|en|><|transcribe|><|notimestamps|> tokens when they put the previous text in the <|startofprev|> section (since that would be problematic). But I'm not sure if they also strip out the actual timestamps such as <|1.5|> etc, that would be worth looking into.

I think get_prompt_ids should not accept any tokens >= processor.tokenizer.eos_token, so we should either strip these out or raise an error. If we do want to allow timestamp tokens, then it should accept tokens > processor.tokenizer.all_special_ids[-1], since that's where the timestamp tokens begin.

Contributor Author

Looks like they have their tiktoken package handle it and it raises an error if any special token is included, so will look to do the same

Contributor Author
connor-henderson May 3, 2023

These are the tokens they raise an error on, so timestamps are included. transformers uses the same time_precision of 0.02, but notably, even though <|1.00|> is caught by OpenAI's special tokens check, any number that doesn't have hundredths-place precision like <|1.0|> isn't. Opted to implement catching any positive decimal number inside the special token brackets.
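
A rough sketch of that kind of check (a hypothetical helper, not the merged implementation):

import re

def _validate_prompt_text(text, special_tokens):
    # reject explicit special tokens as well as anything that looks like a
    # timestamp token, i.e. a positive decimal inside token brackets such as <|1.0|>
    timestamp_like = re.compile(r"<\|\d+(\.\d+)?\|>")
    if any(tok in text for tok in special_tokens) or timestamp_like.search(text):
        raise ValueError("Prompt text should not contain special or timestamp tokens.")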

Contributor

Your solution is probably fine but a simpler approach would be to make sure no token has a value >= processor.tokenizer.eos_token. ;-)

Contributor Author

Oh interesting, for timestamps and <|nospeech|> too? I get several low ids when trying to tokenize a timestamp

[Screenshot of the tokenizer output omitted]

Contributor

Maybe related: #20225.

Contributor

I guess our tokenizer only decodes timestamp tokens, but doesn't know how to encode them?

Timestamps start at 50364. Any token id higher than that is a timestamp token.

Contributor

The reason I suggested putting it inside a separate if skip_special_tokens: is that the in operation needs to scan through the list, which is slow, and we can avoid it if skip_special_tokens is False.

Contributor

Same here. I'd prefer to do the has_prompt check after checking for skip_special_tokens. (It's only a small thing and skip_special_tokens will be True most of the time anyway, so consider this a nitpick. ;-) )

Contributor

I agree it's very readable but there's a potential issue: the model will keep state around, i.e. it will change the model.generation_config object with the new forced_decoder_ids and this may affect the next test. So I think it's better to instantiate a new model before each test.

Maybe it's also a good idea to test what happens after you do the following, just to make sure the code can handle both of these things being None:

model.config.forced_decoder_ids = None
model.generation_config.forced_decoder_ids = None

Contributor Author

This is done solely for handling the case where prompt_ids are passed in but the generation config's and model config's forced decoder ids are both None. It's essentially just changing the order of operations so that we can cleanly check forced_decoder_ids is None and prompt_ids is not None to then add non-prompt forced decoder ids; none of the other functionality should change.

@amyeroberts
Contributor

@AvivSham Thanks for reporting and @connor-henderson thanks for investigating!

I think we're good to merge 👍

@amyeroberts amyeroberts merged commit 2acedf4 into huggingface:main May 19, 2023
@dgram0

dgram0 commented May 20, 2023

Thank you so much for adding this! I've found that I occasionally get the following:

Traceback (most recent call last):
  File "G:\Conda\hfwhisper\lib\site-packages\transformers\models\whisper\modeling_whisper.py", line 1662, in generate
    return super().generate(
  File "G:\Conda\hfwhisper\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "G:\Conda\hfwhisper\lib\site-packages\transformers\generation\utils.py", line 1518, in generate
    return self.greedy_search(
  File "G:\Conda\hfwhisper\lib\site-packages\transformers\generation\utils.py", line 2345, in greedy_search
    next_token_logits = outputs.logits[:, -1, :]
IndexError: index -1 is out of bounds for dimension 1 with size 0

My workaround is to catch the exception and try again without the prompt_ids.
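
Roughly along these lines (names follow the earlier examples in this thread):

try:
    pred_ids = model.generate(input_features, prompt_ids=prompt_ids)
except IndexError:
    pred_ids = model.generate(input_features)  # retry without the prompt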

@hollance
Contributor

Do you have a reproducible example for this @dgram0? That seems like a serious enough bug that needs investigating further.

@hollance
Contributor

> @Johnson-NLP: Is it possible to add 'initial_prompt' in the Fine-Tune code with a 'prompt_use_rate' to control how often to add prompts to the sentences in training sets?

Sounds like an interesting idea. Would you mind opening a new issue for this? Thanks!

@sanchit-gandhi
Contributor

sanchit-gandhi commented May 22, 2023

To get prompting working with fine-tuning, we probably don't want to explicitly add 'prompted' examples per-se, but rather split longer examples up into shorter ones and feed them sequentially through the model, providing previous passages as 'context' to the model.

For example, if we had a training sample that looked like:

This is the first sentence. This is the second sentence. And finally, this is the third.

Currently what we do is feed it to the model all at once:

<|startoftranscript|> This is the first sentence. This is the second sentence. And finally, this is the third. <|endoftranscript|>

What we can do is feed the first sentence in:

<|startoftranscript|> This is the first sentence. <|endoftranscript|>

Then the second sentence, with the first sentence as context:

<|startofprev|> This is the first sentence.<|startoftranscript|> This is the second sentence. <|endoftranscript|>

And then the third, with both the first and second sentences as context:

<|startofprev|> This is the first sentence. This is the second sentence.<|startoftranscript|>  And finally, this is the third.<|endoftranscript|>

At inference time, we then just provide the "context" as our prompts:

<|startofprev|> This is the prompt.<|startoftranscript|> (model generates the rest)

See section 2.3 of the Whisper paper for an in-depth explanation as to how they achieve this during pre-training. We essentially want to do the same for fine-tuning.

For this to work, ideally we need an original sentence that is >> 30s in duration. That way when we split it up, we don't have super short examples that we feed to the model.
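
For concreteness, a minimal sketch of the label layout described above (illustrative of the token ordering only, assuming a WhisperTokenizer instance named tokenizer; not a full fine-tuning recipe):

prev_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")

def build_labels(context_text, target_text):
    # the context goes in the <|startofprev|> section, without special tokens
    context_ids = tokenizer(" " + context_text.strip(), add_special_tokens=False).input_ids
    # the target gets the usual <|startoftranscript|> ... <|endoftext|> wrapping
    target_ids = tokenizer(" " + target_text.strip()).input_ids
    return [prev_token_id] + context_ids + target_ids

labels = build_labels("This is the first sentence.", "This is the second sentence.")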

@connor-henderson connor-henderson deleted the whisper-prompting branch May 22, 2023 17:55
@dgram0

dgram0 commented May 23, 2023

> Do you have a reproducible example for this @dgram0? That seems like a serious enough bug that needs investigating further.

I'll try reproducing in a small toy example. It's reproducible on my side with the fine-tuned large private model I've been working with.

@dgram0

dgram0 commented May 23, 2023

> Do you have a reproducible example for this @dgram0? That seems like a serious enough bug that needs investigating further.

The following triggers the bug on the 13th iteration of the loop. (Usually, it takes a lot more iterations.)

from datasets import load_dataset, DatasetDict
from transformers import WhisperForConditionalGeneration, WhisperProcessor

it = iter(load_dataset("librispeech_asr", "all", split="test.other", streaming=True))
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="English", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
prompt = 'some text rich in domain specific vocabulary lives here'
past_prompts = ["I am from the cutter lying off the coast"]
while it:
  _ = [next(it) for x in range(3)]
  clip = next(it)
  input_features = processor(clip['audio']['array'], sampling_rate=clip['audio']['sampling_rate'], return_tensors="pt").input_features
  prompt_ids = processor.get_prompt_ids(prompt + ' - ' + ' - '.join(past_prompts))
  pred_ids = model.generate(input_features, language="english", task="transcribe", max_new_tokens=128, prompt_ids=prompt_ids)
  result = processor.batch_decode(pred_ids, skip_special_tokens=True)[0].strip()
  result_text = result.removesuffix('.')
  print(result_text)
  if result_text != '':
    past_prompts.append(result_text)
    if len(past_prompts) > 12:
      past_prompts = past_prompts[1:]

@connor-henderson
Contributor Author

@dgram0 thanks for sharing, I was able to repro this. As far as its relation to prompting goes, I think this is another case of prompt sensitivity as opposed to a bug, but it may still be of interest with regards to Whisper generally since it's the same error message as issue #22682.

I noticed that joining the prompts by ' - ' was causing the model to start predicting Chinese characters, and using '. ' instead did not lead to the error (at least through 30 loops, at which point I stopped testing). I did notice degraded predictions over time though, since a period did not necessarily belong after each result, and every now and again a Chinese char was still predicted, so I'd just be cautious about how prompts are chained together.

@dgram0

dgram0 commented May 23, 2023

@connor-henderson It's a bit of a contrived example meant just to recreate the issue without having to loop too much and at the same time show what may be considered a normal use case. Even without it predicting non-English characters or words you'll eventually encounter the issue within a few hundred loops.

@dgram0

dgram0 commented May 23, 2023

> @dgram0 thanks for sharing, I was able to repro this. As far as its relation to prompting goes, I think this is another case of prompt sensitivity as opposed to a bug, but it may still be of interest with regards to Whisper generally since it's the same error message as issue #22682.
>
> I noticed that joining the prompts by ' - ' was causing the model to start predicting Chinese characters, and using '. ' instead did not lead to the error (at least through 30 loops, at which point I stopped testing). I did notice degraded predictions over time though, since a period did not necessarily belong after each result, and every now and again a Chinese char was still predicted, so I'd just be cautious about how prompts are chained together.

The following still joins the prompts using ' - ', doesn't allow non-English characters in the prompts, doesn't seem to predict Chinese characters, does a decent job of transcription, and still fails on the 144th loop.

from datasets import load_dataset, DatasetDict
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import re
import torch

it = iter(load_dataset("librispeech_asr", "all", split="test.other", streaming=True))
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", language="English", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
_ = model.to(device)
prompt = 'Some text rich in domain specific vocabulary and example format lives here.'
past_prompts = ["I am from the cutter lying off the coast."]
while it:
  clip = next(it)
  input_features = processor(clip['audio']['array'], sampling_rate=clip['audio']['sampling_rate'], return_tensors="pt").input_features
  prompt_ids = processor.get_prompt_ids(prompt + ' - ' + ' - '.join(past_prompts))
  if device.type == 'cuda':
    input_features = input_features.cuda()
  pred_ids = model.generate(input_features, language="english", task="transcribe", max_new_tokens=128, prompt_ids=prompt_ids)
  result = processor.batch_decode(pred_ids, skip_special_tokens=True)[0].strip()
  result_text = re.sub(r"[^\u0000-\u05C0\u2100-\u214F]+$", "", result)
  print(result)
  if result_text != '':
    past_prompts.append(result_text)
    if len(past_prompts) > 12:
      past_prompts = past_prompts[1:]

@connor-henderson
Contributor Author

Thanks @dgram0, in that case I think this is a bug. I opened issue #23723 and PR #23724 for both this and another bug this made me realize exists, where max_new_tokens isn't properly enforced when the prompt_ids length is too large. I think they both have the same root cause.

@hollance
Contributor

hollance commented May 24, 2023

Thanks, @dgram0. Would you have time to look at this bug @connor-henderson, since you're most familiar with this code? If not, I can have a look.

EDIT: LOL, I'm way too slow. Should probably refresh my browser before commenting. Thanks for making these new issues, Connor. 😄

@hollance
Contributor

@connor-henderson @sanchit-gandhi Hey, did we ever resolve the add_prefix_space issue?

If I do the following,

pipe = pipeline(task="automatic-speech-recognition", model="openai/whisper-tiny")
prompt_ids = pipe.tokenizer.get_prompt_ids("Hello, world!", return_tensors="pt")

I get the error,

TypeError: _batch_encode_plus() got an unexpected keyword argument 'add_prefix_space'

It works fine if I create a processor or tokenizer object by hand and call get_prompt_ids().

I seem to recall this issue came up before but not sure if anything was decided for it?

@connor-henderson
Contributor Author

@hollance @versae I missed that, just looked into it. It appears to be a difference between the slow tokenizer accepting add_prefix_space and the fast tokenizer not recognizing or applying it; opened an issue here: #23764

gojiteji pushed a commit to gojiteji/transformers that referenced this pull request Jun 5, 2023
* initial working additions

* clean and rename, add cond stripping initial prompt to decode

* cleanup, edit create_initial_prompt_ids, add tests

* repo consistency, flip order of conditional

* fix error, move the processor fn to the tokenizer

* repo consistency, update test ids to corresponding tokenizer

* use convert_tokens_to_ids not get_vocab...

* use actual conditional in generate

* make sytle

* initial address comments

* initial working add new params to pipeline

* first draft of sequential generation for condition_on_previous_text

* add/update tests, make compatible with timestamps

* make compatible with diff. input kwargs and max length

* add None check

* add temperature check

* flip temp check operand

* refocusing to prev pr scope

* remove the params too

* make style

* edits, move max length incorporating prompt to whisper

* address comments

* remove asr pipeline prompt decoding, fix indexing

* address comments (more tests, validate prompt)

* un-comment out tests (from debug)

* remove old comment

* address comments

* fix typo

* remove timestamp token from test

* make style

* cleanup

* copy method to fast tokenizer, set max_new_tokens for test

* prompt_ids type just pt

* address Amy's comments

* make style
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023

forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]

output = model.generate(
input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
Contributor

Why do we allow passing prompt_ids as a numpy array here?

task=None,
language=None,
is_multilingual=None,
prompt_ids: Optional[torch.Tensor] = None,
Contributor

I think prompt_ids should not be allowed to be a numpy array given its signature (see: https://github.com/huggingface/transformers/pull/22496/files#r1467369773)

mattt pushed a commit to mattt/transformers that referenced this pull request May 9, 2024
